//Change code to generate the matrix, print it, and multiply it.
import java.util.*;
import java.io.*;
import static java.lang.System.*;

import mpij.*;
import mpij.msg.*;
import mpij.msg.op.*;

/**
 * Splits one dimension of a square matrix into equally sized, contiguous
 * index ranges ("blocks"), one per row/column of a square process grid.
 *
 * The grid side is sqrt(nproc), so the number of processes has to be a
 * perfect square (1, 4, 9, 16, ...) and the matrix size has to be a multiple
 * of the grid side.  Both conditions are checked by the constructor.
 */
class Divider
	{
	private final Block[] m_blocks;
	
	/** Inclusive index range [m_start, m_end] covered by one block. */
	private static class Block
		{
		public final int m_start;
		public final int m_end;
		public Block(int start, int end)
			{
			m_start = start;
			m_end = end;
			}
		}
		
	//---------------------------------------------------------------------------
	/**
	 * @param matrixSz number of rows (= columns) of the full matrix
	 * @param nproc    number of processes; must be a positive perfect square
	 * @throws IllegalArgumentException if nproc is not a positive perfect
	 *         square, or matrixSz is negative or not divisible by sqrt(nproc)
	 */
	public Divider(int matrixSz, int nproc)
		{
		if (nproc <= 0)
			throw new IllegalArgumentException(
					"Number of processes must be positive: " + nproc);
		
		//Side length of the square process grid
		int blockArraySz = (int)Math.sqrt(nproc);
		if (blockArraySz * blockArraySz != nproc)
			throw new IllegalArgumentException(
					"Number of processes must be a perfect square: " + nproc);
		
		//A remainder would leave the trailing rows/columns out of every block
		if ((matrixSz < 0) || (matrixSz % blockArraySz != 0))
			throw new IllegalArgumentException(
					"Matrix size " + matrixSz + " is not a non-negative multiple of "
					+ blockArraySz);
		
		int baseBlockSz = matrixSz / blockArraySz;
		
		m_blocks = new Block[blockArraySz];
		
		int start = 0;
		for (int I = 0; I < blockArraySz; I++)
			{
			m_blocks[I] = new Block(start, (start + baseBlockSz -1));
			start += baseBlockSz;
			}
		}
		
	//---------------------------------------------------------------------------
	/** @return first matrix index (inclusive) of the block at the given position */
	public int getStart(int index)
		{
		return (m_blocks[index].m_start);
		}
		
	//---------------------------------------------------------------------------
	/** @return last matrix index (inclusive) of the block at the given position */
	public int getEnd(int index)
		{
		return (m_blocks[index].m_end);
		}
		
	//---------------------------------------------------------------------------
	/**
	 * @return the number of blocks per dimension, i.e. the side length of the
	 *         process grid (not the number of elements in a block)
	 */
	public int getBlockSize()
		{
		return (m_blocks.length);
		}
	}

//==============================================================================
/**
 * Square matrix of long values together with the helpers needed for a
 * block-wise distributed multiplication: reading/writing a whitespace
 * separated text file, random generation, cutting out the block a process
 * starts with, pasting a block back, and a plain triple-loop multiply.
 */
class Matrix
	{
	private final long[][] m_matrix;
	private final Random m_random;
	
	/** Creates a size x size matrix filled with zeros. */
	public Matrix(int size)
		{
		m_matrix = new long[size][size];
		m_random = new Random();
		}
	
	//---------------------------------------------------------------------------
	/**
	 * Reads the next whitespace delimited token from the reader and parses it
	 * as a long.  Leading whitespace (spaces, tabs, line breaks) is skipped and
	 * the single whitespace character ending the token is consumed.
	 *
	 * @throws NumberFormatException if the token is not a valid long, which
	 *         includes hitting end of input before any digit was read
	 */
	private long readLong(Reader reader)
			throws IOException
		{
		int ch = reader.read();
		
		//Skip separators in front of the token
		while ((ch != -1) && Character.isWhitespace(ch))
			ch = reader.read();
		
		StringBuilder sb = new StringBuilder();
		while ((ch != -1) && !Character.isWhitespace(ch))
			{
			sb.append((char)ch);
			ch = reader.read();
			}
			
		return (Long.parseLong(sb.toString()));
		}
		
	//---------------------------------------------------------------------------
	/** @return number of rows (= columns) */
	public int getSize()
		{
		return (m_matrix.length);
		}
	
	//---------------------------------------------------------------------------
	/** @return the backing array itself (not a copy); changes write through */
	public long[][] getMatrix()
		{
		return (m_matrix);
		}
		
	//---------------------------------------------------------------------------
	/**
	 * Fills the matrix row by row with size*size values read from a text file
	 * of whitespace separated integers (the format writeToFile produces).
	 *
	 * @throws IOException if the file cannot be opened or read
	 * @throws NumberFormatException if the file holds too few or invalid values
	 */
	public void readMatrix(String file)
			throws IOException
		{
		//Buffered: readLong pulls one character at a time
		Reader reader = new BufferedReader(new FileReader(file));
		try
			{
			int size = m_matrix.length;
			
			for (int I = 0; I < size; I++)
				for (int J = 0; J < size; J++)
					m_matrix[I][J] = readLong(reader);
			}
		finally
			{
			reader.close();
			}
		}
		
	//---------------------------------------------------------------------------
	/** Prints the matrix to standard output, one row per line. */
	public void print()
		{
		int size = m_matrix.length;
		
		for (int I = 0; I < size; I++)
			{
			for (int J = 0; J < size; J++)
				System.out.print(m_matrix[I][J]+" ");
			System.out.println();
			}
		}
		
	//---------------------------------------------------------------------------
	/** Fills the matrix with random values in the range 0..999. */
	public void generate()
		{
		int size = m_matrix.length;
		
		//nextInt(bound) never returns a negative value
		for (int I = 0; I < size; I++)
			for (int J = 0; J < size; J++)
				m_matrix[I][J] = m_random.nextInt(1000);
		}
		
	//---------------------------------------------------------------------------
	/** @return (row + col) wrapped around into 0..size-1 */
	private int shiftIndex(int row, int col, int size)
		{
		return ((row + col) % size);
		}
		
	//---------------------------------------------------------------------------
	/**
	 * Copies out the block that process iproc starts the distributed multiply
	 * with.  The process sits at grid position (r, c) with r = iproc / n and
	 * c = iproc % n, where n = divider.getBlockSize().
	 *
	 *   'a' : block A[r][(r + c) % n]   (column index skewed)
	 *   'b' : block B[(r + c) % n][c]   (row index skewed)
	 *   'c' or anything else : block [r][c], no skew
	 *
	 * With this skew the block column of the 'a' block equals the block row of
	 * the 'b' block, so their local product is a valid term of C[r][c].
	 *
	 * @param iproc   rank of the process the block is meant for
	 * @param divider block boundaries of this matrix
	 * @param matrix  'a', 'b' or 'c', selecting the skew described above
	 * @return a new matrix holding a copy of the selected block
	 */
	public Matrix getSubMatrix(int iproc, Divider divider, char matrix)
		{
		int gridSz = divider.getBlockSize();
		int blockRow = iproc / gridSz;
		int blockCol = iproc % gridSz;
		
		switch (matrix)
			{
			case 'a':
				blockCol = shiftIndex(blockRow, blockCol, gridSz);
				break;
			case 'b':
				blockRow = shiftIndex(blockRow, blockCol, gridSz);
				break;
			default:
				//'c' (and unknown selectors): plain block, no skew
				break;
			}
			
		int startRow = divider.getStart(blockRow);
		int startCol = divider.getStart(blockCol);
		int height = divider.getEnd(blockRow) - startRow +1;
		int width = divider.getEnd(blockCol) - startCol +1;
		
		Matrix newMatrix = new Matrix(width);
		
		for (int I = 0; I < height; I++)
			System.arraycopy(m_matrix[startRow + I], startCol,
					newMatrix.m_matrix[I], 0, width);
			
		return (newMatrix);
		}
		
	//---------------------------------------------------------------------------
	/**
	 * Copies the given block into this matrix at the (unskewed) grid position
	 * of process iproc: block row iproc / n, block column iproc % n with
	 * n = div.getBlockSize().
	 *
	 * @param matrix block to paste; must be at least as large as one block
	 * @param div    block boundaries of this matrix
	 * @param iproc  rank whose block position is overwritten
	 */
	public void replaceSubMatrix(Matrix matrix, Divider div, int iproc)
		{
		int gridSz = div.getBlockSize();
		int blockRow = iproc / gridSz;
		int blockCol = iproc % gridSz;
		
		int startRow = div.getStart(blockRow);
		int startCol = div.getStart(blockCol);
		int height = div.getEnd(blockRow) - startRow +1;
		int width = div.getEnd(blockCol) - startCol +1;
		
		for (int I = 0; I < height; I++)
			System.arraycopy(matrix.m_matrix[I], 0,
					m_matrix[startRow + I], startCol, width);
		}
		
	//---------------------------------------------------------------------------
	/**
	 * Adds the product (this x mat) onto ans.  ans is NOT cleared first, so
	 * repeated calls accumulate partial products.  All three matrices are
	 * expected to have the same size as this one.
	 */
	public void multiply(Matrix mat, Matrix ans)
		{
		int size = m_matrix.length;
		
		long[][] a = m_matrix;
		long[][] b = mat.m_matrix;
		long[][] c = ans.m_matrix;
		
		for (int I = 0; I < size; I++)
			for (int J = 0; J < size; J++)
				for (int K = 0; K < size; K++)
					c[I][J] += a[I][K] * b[K][J];
		}
		
	//---------------------------------------------------------------------------
	/**
	 * Writes all values row by row to a text file, each followed by a single
	 * space and without line breaks.  I/O problems are reported on standard
	 * output instead of being thrown.
	 */
	public void writeToFile(String file)
		{
		PrintWriter pw = null;
		try
			{
			pw = new PrintWriter(new FileWriter(file));
			int size = m_matrix.length;
			
			for (int I = 0; I < size; I++)
				for (int J = 0; J < size; J++)
					{
					pw.print(m_matrix[I][J]);
					pw.print(' ');
					}
					
			//PrintWriter swallows IOExceptions; this is the only way to see them
			if (pw.checkError())
				out.println("Error while writing matrix to " + file);
			}
		catch (IOException ioe)
			{
			out.println(ioe);
			}
		finally
			{
			if (pw != null)
				pw.close();
			}
		}
	}
	
//==============================================================================
/**
 * Distributed multiplication of two square matrices on a square grid of
 * processes, plus a stand-alone main() that generates the input files.
 *
 * run() expects MATRIX_SIZE x MATRIX_SIZE matrices in the hard-coded input
 * files; every rank reads both complete files and keeps only its own block.
 */
public class MatrixMult
		implements MPIRunnable
	{
	private static final int MATRIX_SIZE = 1200;
	private Communicator m_commWorld;
	private int m_nproc;
	private int m_iproc;
	
	/**
	 * Generates two random size x size matrices, writes them to amat.txt and
	 * bmat.txt, multiplies them serially and writes the product to ans.txt.
	 *
	 * @param args args[0] is the matrix size
	 * @throws IllegalArgumentException if the size argument is missing, not a
	 *         number, or negative
	 */
	public static void main(String[] args)
		{
		if (args.length < 1)
			throw new IllegalArgumentException("usage: java MatrixMult <matrix-size>");
		
		int size = Integer.parseInt(args[0]);
		if (size < 0)
			throw new IllegalArgumentException("matrix size must not be negative: " + size);
		
		Matrix matA, matB, ans;
		
		out.println("Generating Matrix A");
		matA = new Matrix(size);
		matA.generate();
		out.println("Writing Matrix A to file");
		matA.writeToFile("amat.txt");
		
		out.println("Generating Matrix B");
		matB = new Matrix(size);
		matB.generate();
		out.println("Writing Matrix B to file");
		matB.writeToFile("bmat.txt");
		
		out.println("Multiplying Matrix A and B");
		ans = new Matrix(size);
		matA.multiply(matB, ans);
		out.println("Writing answer");
		ans.writeToFile("ans.txt");
		
		}
		
	public MatrixMult()
		{
		}
		
//------------------------------------------------------------------------------
	/**
	 * Ranks are laid out row by row on a size x size grid (row = iproc / size).
	 * @return rank one row up in the same column, wrapping from row 0 to the
	 *         last row
	 */
	public static int getTopNeighbor(int iproc, int size)
		{
		int row = iproc / size;
		int col = iproc - (row * size);
		
		//+size keeps the operand of % non-negative for row 0
		row = (row + size -1) % size;
		
		return (col + (size * row));
		}
	
//------------------------------------------------------------------------------
	/**
	 * @return rank one row down in the same column, wrapping from the last row
	 *         to row 0
	 */
	public static int getBottomNeighbor(int iproc, int size)
		{
		int row = iproc / size;
		int col = iproc - (row * size);
		
		row = (row +1) % size;
		
		return (col + (size * row));
		}
	
//------------------------------------------------------------------------------
	/**
	 * @return rank one column to the left in the same row, wrapping from
	 *         column 0 to the last column
	 */
	public static int getLeftNeighbor(int iproc, int size)
		{
		int row = iproc / size;
		int col = iproc - (row * size);
		
		//+size keeps the operand of % non-negative for column 0
		col = (col + size -1) % size;
		
		return (col + (size * row));
		}
	
//------------------------------------------------------------------------------
	/**
	 * @return rank one column to the right in the same row, wrapping from the
	 *         last column to column 0
	 */
	public static int getRightNeighbor(int iproc, int size)
		{
		int row = iproc / size;
		int col = iproc - (row * size);
		
		col = (col+1) % size;
		
		return (col + (size * row));
		}
		
	/**
	 * Per-rank entry point.  Reads both input matrices, cuts out this rank's
	 * starting blocks, then alternates local multiply-accumulate with a cyclic
	 * shift of the A blocks to the right and the B blocks downwards.  Finally
	 * rank 0 collects every rank's C block and prints the elapsed time.
	 *
	 * Errors are printed to standard output, not propagated.
	 */
	public void run(Communicator commWorld)
		{
		long startTime = 0;
		long endTime = 0;
		
		try
			{
			m_commWorld = commWorld;
			m_nproc = m_commWorld.getSize();
			m_iproc = m_commWorld.getRank();
			System.out.println(m_iproc +" Started");
			
			Divider div = new Divider(MATRIX_SIZE, m_nproc);
			
			//Every rank loads the full matrix and keeps only its block;
			//the full copy is dropped right away to free memory.
			Matrix matA = new Matrix(MATRIX_SIZE);
			matA.readMatrix("/ibrix/home/beh54/matrix/amat.txt");
			Matrix myAMatrix = matA.getSubMatrix(m_iproc, div, 'a');
			matA = null;
			
			Matrix matB = new Matrix(MATRIX_SIZE);
			matB.readMatrix("/ibrix/home/beh54/matrix/bmat.txt");
			Matrix myBMatrix = matB.getSubMatrix(m_iproc, div, 'b');
			matB = null;
			
			Matrix myCMatrix = new Matrix(myAMatrix.getSize());
			
			//Timing covers computation and communication, not the file reads
			startTime = System.currentTimeMillis();
			
			System.out.println("Multiplying on "+m_iproc);
			myAMatrix.multiply(myBMatrix, myCMatrix);
			
			int blockSz = div.getBlockSize();
			int remoteNode;
			
			//The messages wrap the block arrays themselves.  The multiply after
			//each recv only sees new data if recv overwrites the wrapped array
			//in place - NOTE(review): confirm against the mpij documentation.
			Message aMsg = new Long2DMessage(myAMatrix.getMatrix());
			Message bMsg = new Long2DMessage(myBMatrix.getMatrix());
			
			//NOTE(review): every rank sends before it receives; if mpij's send
			//blocks until the matching recv is posted this can deadlock - confirm.
			for (int I = 0; I < blockSz -1; I++)
				{
				//Send B matrix down
				remoteNode = getBottomNeighbor(m_iproc, blockSz);
				m_commWorld.send(bMsg, remoteNode, 0);
				
				//Send A matrix right
				remoteNode = getRightNeighbor(m_iproc, blockSz);
				m_commWorld.send(aMsg, remoteNode, 0);
				
				//Receive B matrix
				remoteNode = getTopNeighbor(m_iproc, blockSz);
				m_commWorld.recv(bMsg, remoteNode, 0);
				
				//Receive A matrix
				remoteNode = getLeftNeighbor(m_iproc, blockSz);
				m_commWorld.recv(aMsg, remoteNode, 0);
				
				
				//Multiply matrix A and B
				myAMatrix.multiply(myBMatrix, myCMatrix);
				}
				
			if (m_iproc == 0)
				{
				//Assemble the full result; it is built but not written anywhere
				Matrix ans = new Matrix(MATRIX_SIZE);
				ans.replaceSubMatrix(myCMatrix, div, 0);
				
				//myCMatrix doubles as the receive buffer for the other ranks' blocks
				Message cMsg = new Long2DMessage(myCMatrix.getMatrix());
				
				for (int I = 1; I < m_nproc; I++)
					{
					m_commWorld.recv(cMsg, I, 0);
					ans.replaceSubMatrix(myCMatrix, div, I);
					}
				}
			else
				{
				m_commWorld.send(new Long2DMessage(myCMatrix.getMatrix()), 0, 0);
				}
			
			endTime = System.currentTimeMillis();
			
			if (m_iproc == 0)
				{
				long elapsed = endTime - startTime;
				out.println("total "+((double)elapsed)/1000+" sec");
				}
			}
		catch (OutOfMemoryError oome)
			{
			System.out.println(oome);
			oome.printStackTrace();
			}
		catch (IOException ioe)
			{
			System.out.println(ioe);
			ioe.printStackTrace();
			}
		catch (MPIException mpie)
			{
			System.out.println(mpie);
			mpie.printStackTrace();
			}
		}
	}
	
